/*
 * Lincheck
 *
 * Copyright (C) 2019 - 2025 JetBrains s.r.o.
 *
 * This Source Code Form is subject to the terms of the
 * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed
 * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

package org.jetbrains.kotlinx.lincheck.transformation.transformers

import org.jetbrains.kotlinx.lincheck.transformation.ASM_API
import org.objectweb.asm.ConstantDynamic
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes

/**
 * Filters out chunks of bytecode generated by the IntelliJ coverage agent. The class relies heavily on the
 * exact instruction patterns inserted by the coverage agent.
 *
 * The filter runs a small state machine (see [State]) over the incoming instructions. Instructions recognized as
 * coverage bookkeeping are sent to [initialVisitor] instead of the delegate visitor. These are:
 * - every `GETSTATIC`/`PUTSTATIC` of a field whose name contains `__$hits$__`;
 * - every `INVOKESTATIC` of one of the `CoverageRuntime` hits getters;
 * - the `IFNONNULL` that directly follows a hits field/getter access;
 * - the `BASTORE`/`IASTORE`/`IALOAD` performed right after loading the hits array from its local.
 *
 * All other instructions go to the delegate visitor as usual. This includes the `LDC`, `ASTORE` and `ALOAD`
 * instructions that move the hits array around.
 *
 * @param initialVisitor initial method visitor to delegate work to when coverage-related bytecode is encountered.
 * @param methodVisitor method visitor to delegate calls to when bytecode is not related to coverage.
 *
 * @see <a href="https://github.com/JetBrains/intellij-coverage/blob/master/instrumentation/src/com/intellij/rt/coverage/instrumentation/InstrumentationUtils.java#L41">Intellij-coverage instrumentation</a>
 */
internal class CoverageBytecodeFilter(
    private val initialVisitor: MethodVisitor,
    methodVisitor: MethodVisitor,
) : MethodVisitor(ASM_API, methodVisitor) {

    /** States of the coverage-pattern recognizer. */
    enum class State {
        /** No coverage-related instruction sequence is being tracked. */
        INITIAL,

        /** An `LDC` of the `__$hits$__` dynamic constant was seen; an `ASTORE` moves to [HITS_IN_LOCAL]. */
        HITS_INIT,

        /**
         * The hits array was obtained via `GETSTATIC` of a hits field or via a `CoverageRuntime` hits getter.
         * An `ASTORE` moves to [HITS_IN_LOCAL]; a hits `PUTSTATIC` or an `IFNONNULL` moves back to [INITIAL].
         */
        HITS_INIT_FIELD,

        /** The hits array has been stored into a tracked local slot; an `ALOAD` of that slot moves to [HITS_BEFORE_ASSIGN]. */
        HITS_IN_LOCAL,

        /**
         * The hits local was just loaded. `IALOAD` is filtered. `BASTORE`/`IASTORE` is filtered and
         * moves back to [HITS_IN_LOCAL].
         */
        HITS_BEFORE_ASSIGN,
    }

    private var state = State.INITIAL

    /** Local slot the hits array was stored into, or `-1` if none has been recorded yet. */
    private var localVariableIndex = -1

    override fun visitLdcInsn(value: Any) {
        if (state == State.INITIAL && value is ConstantDynamic && value.name == COVERAGE_HITS_NAME) {
            state = State.HITS_INIT
        }
        // The LDC itself is not filtered: it is still passed to the delegate visitor.
        super.visitLdcInsn(value)
    }

    override fun visitFieldInsn(opcode: Int, owner: String, name: String, desc: String) {
        if (name.contains(COVERAGE_HITS_NAME)) {
            when (opcode) {
                Opcodes.GETSTATIC -> if (state == State.INITIAL) {
                    state = State.HITS_INIT_FIELD
                }
                Opcodes.PUTSTATIC -> if (state == State.HITS_INIT_FIELD) {
                    state = State.INITIAL
                }
            }

            // Any access to a hits field bypasses the delegate chain, regardless of the current state.
            initialVisitor.visitFieldInsn(opcode, owner, name, desc)
            return
        }
        super.visitFieldInsn(opcode, owner, name, desc)
    }

    override fun visitJumpInsn(opcode: Int, label: Label?) {
        // Null check on a hits value just obtained from a field or getter: part of the coverage bookkeeping.
        if (opcode == Opcodes.IFNONNULL && state == State.HITS_INIT_FIELD) {
            state = State.INITIAL
            initialVisitor.visitJumpInsn(opcode, label)
            return
        }
        super.visitJumpInsn(opcode, label)
    }

    override fun visitInsn(opcode: Int) {
        when (opcode) {
            // Writing a probe hit into the hits array (boolean mask or int counters).
            Opcodes.BASTORE, Opcodes.IASTORE -> if (state == State.HITS_BEFORE_ASSIGN) {
                state = State.HITS_IN_LOCAL
                initialVisitor.visitInsn(opcode)
                return
            }
            // Reading the current counter value before incrementing it (int-counter variant).
            Opcodes.IALOAD -> if (state == State.HITS_BEFORE_ASSIGN) {
                initialVisitor.visitInsn(opcode)
                return
            }
        }

        super.visitInsn(opcode)
    }

    override fun visitVarInsn(opcode: Int, index: Int) {
        when (opcode) {
            Opcodes.ASTORE -> {
                // Remember which local now holds the freshly obtained hits array.
                if (state == State.HITS_INIT || state == State.HITS_INIT_FIELD) {
                    state = State.HITS_IN_LOCAL
                    localVariableIndex = index
                }
            }
            Opcodes.ALOAD -> if (state == State.HITS_IN_LOCAL && localVariableIndex == index) {
                state = State.HITS_BEFORE_ASSIGN
            }
        }

        // Local loads/stores themselves are not filtered.
        super.visitVarInsn(opcode, index)
    }

    override fun visitMethodInsn(
        opcode: Int,
        owner: String,
        name: String,
        descriptor: String,
        isInterface: Boolean,
    ) {
        // The opcode is checked first so the signature string is only built for static calls.
        if (opcode == Opcodes.INVOKESTATIC && "$owner.$name$descriptor" in COVERAGE_HITS_METHOD_SIGNATURES) {
            if (state == State.INITIAL) {
                state = State.HITS_INIT_FIELD
            }

            initialVisitor.visitMethodInsn(opcode, owner, name, descriptor, isInterface)
            return
        }

        super.visitMethodInsn(opcode, owner, name, descriptor, isInterface)
    }

    override fun visitEnd() {
        // Forget any partially-tracked pattern before finishing the method.
        state = State.INITIAL
        localVariableIndex = -1

        super.visitEnd()
    }

    private companion object {
        /** Name of the coverage hits dynamic constant; also matched as a substring of hits field names. */
        const val COVERAGE_HITS_NAME = "__\$hits\$__"

        /**
         * Signatures (`owner.name` + descriptor) of the `CoverageRuntime` static methods returning the hits array.
         *
         * @see <a href="https://github.com/JetBrains/intellij-coverage/blob/master/test-kotlin/resources/bytecode/simple/branches">Hits masks static methods</a>
         */
        val COVERAGE_HITS_METHOD_SIGNATURES: Set<String> = setOf(
            "com/intellij/rt/coverage/instrumentation/CoverageRuntime.getHitsMask(Ljava/lang/String;)[Z",
            "com/intellij/rt/coverage/instrumentation/CoverageRuntime.getHitsMaskCached(Ljava/lang/String;)[Z",
            "com/intellij/rt/coverage/instrumentation/CoverageRuntime.getHits(Ljava/lang/String;)[I",
            "com/intellij/rt/coverage/instrumentation/CoverageRuntime.getHitsCached(Ljava/lang/String;)[I",
        )
    }
}